# Copyright (c) 2011-2012, Universite de Versailles St-Quentin-en-Yvelines
#
# This file is part of ASK.  ASK is free software: you can redistribute
# it and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import os
import shutil
import subprocess
import tempfile
from tree import *


def parse_cart_tree(lines, categories):
    """
    Parses CART trees generated by R module 'rpart'.

    lines: an array of strings containing the printed output of the 'rpart'
    module. The first five lines are treated as a header and skipped; every
    remaining non-blank line describes one node, starting with "<number>)"
    followed by the split condition. The node prediction is read from the
    fifth whitespace-separated field once the split condition has been
    folded into a single field.
    categories: array that maps each factor to the categorical values
    accepted by that factor. For categorical splits only its length is used:
    the categories of factor k are assumed to be encoded as the integers
    0 .. len(categories[k]) - 1.

    Returns the CART tree in tree.py format (Leaf / Node objects).
    """
    nodes = {}
    for l in lines[5:]:
        # read a node
        sp = l.split()
        # Tolerate blank lines (e.g. trailing newlines in the R output).
        if not sp:
            continue

        # The first field is the node number followed by ')', e.g. "2)".
        number = int(sp[0][:-1])

        if sp[1] == "root":
            # Identify the root node
            node = {"prediction": float(sp[4])}
        elif ">" in sp[1] or "<" in sp[1]:
            # Ordinal factors are marked with < or >=.
            # Detect the direction of the split condition: the child that
            # satisfies x >= cut is not always the right one.
            if ">" in sp[1]:
                sep, direction = ">=", 1
            else:
                sep, direction = "<", 0

            # Variable names look like "V<k>" (1-based column index).
            name, _, cut = sp[1].partition(sep)
            if not cut:
                # The cut was printed as a separate token (e.g. "V1< 50.5");
                # consume it so the prediction stays at index 4.
                cut = sp.pop(2)

            node = {"axis": int(name[1:]) - 1,
                    "cut": float(cut),
                    "prediction": float(sp[4]),
                    "direction": direction}
        else:
            # Categorical factors are marked with =, e.g. "V2=0,2".
            name, _, values = sp[1].partition("=")
            node = {"axis": int(name[1:]) - 1,
                    # A real list (map() is lazy on Python 3) so that the
                    # choices can be concatenated and reused later.
                    "choices": [int(v) for v in values.split(",")],
                    "prediction": float(sp[4])}

        # Store the parsed node in the nodes table
        nodes[number] = node

    def get_tree(n):
        """
        This internal function returns the subtree rooted at node n of the
        'rpart' representation, converted to the python tree format defined
        in tree.py. Rpart encodes the root as 1, the left child of n as 2*n,
        and the right child of n as 2*n+1.
        """
        # A node without a left child is a leaf.
        if 2 * n not in nodes:
            return Leaf(model=[nodes[n]["prediction"]])

        # get left and right children
        left = get_tree(2 * n)
        right = get_tree(2 * n + 1)
        # The split condition of node n is stored on its children.
        lnode = nodes[2 * n]
        rnode = nodes[2 * n + 1]
        model = [nodes[n]["prediction"]]

        if "choices" in lnode:
            axis = lnode["axis"]
            lc = lnode["choices"]
            # Categories listed on neither side are sent to the right child.
            # Build a new list (do not extend the parsed node in place) and
            # sort the additions so the result is deterministic.
            missing = (set(range(len(categories[axis])))
                       - set(lc) - set(rnode["choices"]))
            rc = rnode["choices"] + sorted(missing)

            # Build tree.py categorical Node
            return Node(left, right,
                        axis=axis,
                        cut=(lc, rc),
                        model=model,
                        categorical=True)

        # Build tree.py ordinal Node. In the tree.py representation the
        # child satisfying x >= cut is always the right one, so swap the
        # children when rpart printed the ">=" branch first (direction 1).
        if lnode["direction"] == 1:
            left, right = right, left
        return Node(left, right,
                    axis=lnode["axis"],
                    cut=lnode["cut"],
                    model=model,
                    categorical=False)

    # Return the full tree in tree.py format.
    return get_tree(1)


def build_tree(conf, cp, input_file):
    """
    Returns a tree.py CART built from the input file data.

    conf: configuration options; only conf["factors"] is read here. Each
    factor is a dict whose "type" is compared with "categorical";
    categorical factors contribute f["values"] to the category table passed
    to parse_cart_tree, the others contribute f["range"].
    cp: complexity parameter, passed to rpart.control(cp=...).
    input_file: contains the data on which we apply CART regression, the last
    column is the response, the starting columns are the factors.

    The generated R script and its output are written to a temporary
    directory, which is removed before this function returns.

    Raises RuntimeError if the R process exits with a non-zero status.
    """
    factors = conf["factors"]
    # read.table names the columns V1..Vk; the response follows the factors.
    # (Computed from len() so that an empty factor list does not leave the
    # loop variable undefined.)
    response = len(factors) + 1
    categories = []

    # Escape backslashes and double quotes so the path is a valid R string
    # literal (e.g. Windows paths).
    r_path = input_file.replace("\\", "\\\\").replace('"', '\\"')

    # Prepare an R script that applies CART to the input data
    script = ['library(rpart)\n',
              'data = read.table("{0}")\n'.format(r_path)]

    for i, f in enumerate(factors):
        if f["type"] == "categorical":
            script.append('data$V{0} = as.factor(data$V{0})\n'.format(i + 1))
            categories.append(f["values"])
        else:
            categories.append(f["range"])

    script.append("rtree = rpart(V{0} ~ ., data=data, "
                  "control=rpart.control(cp={1}))\n".format(response, cp))

    # Try to do cross validation pruning if enough samples are available
    script.append('if ("xerror" %in% colnames(rtree$cptable)) {\n')
    script.append(
        'bcp=rtree$cptable[which.min(rtree$cptable[,"xerror"]),"CP"]\n')
    script.append('rtree = prune(rtree, bcp)\n}\n')
    script.append("print(rtree)\n")

    temp_dir = tempfile.mkdtemp()
    try:
        ti = os.path.join(temp_dir, "rcart.R")
        to = os.path.join(temp_dir, "out")
        with open(ti, "w") as rfile:
            rfile.write("".join(script))

        # Launch R to build the CART tree. Passing file handles instead of a
        # shell command line avoids quoting problems with the temp paths.
        print("Building CART model ...")
        with open(ti, "r") as r_in:
            with open(to, "w") as r_out:
                status = subprocess.call(["R", "--slave", "--no-save"],
                                         stdin=r_in, stdout=r_out)
        if status != 0:
            raise RuntimeError(
                "R exited with status {0} while building the CART model"
                .format(status))

        # Parse and return the CART tree
        print("Parsing CART tree")
        with open(to, "r") as tf:
            return parse_cart_tree(tf.readlines(), categories)
    finally:
        # Do not leak the temporary script/output files.
        shutil.rmtree(temp_dir, ignore_errors=True)
